[Model] Add support for Jina Embeddings V4 #20802
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Summary of Changes
Hello @sigridjineth, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly expands vLLM's capabilities by integrating native support for the Jina Embeddings V4 multimodal model. My work ensures that users can now efficiently generate embeddings for both text and image inputs, benefiting from highly optimized processing, robust error handling, and comprehensive testing to guarantee accuracy and stability.
Highlights
- Jina Embeddings V4 Integration: I've added full, production-ready support for the Jina Embeddings V4 model (jinaai/jina-embeddings-v4-vllm-retrieval), enabling both text and image multimodal embeddings within vLLM.
- Optimized Multimodal Pooling: The implementation includes thread-safe, token-type-aware pooling, leveraging optimized Triton CUDA kernels for efficient vision token extraction, with a robust PyTorch fallback.
- Robustness and Observability: I've incorporated comprehensive error handling, including graceful fallback mechanisms and OOM recovery during pooling, alongside full observability integration for performance metrics.
- Comprehensive Testing & Examples: New unit tests, a validation benchmark against HuggingFace, and an offline inference example script have been added to ensure accuracy and demonstrate usage (see the usage sketch after this list).
- Documentation Updates: The model has been registered in the model registry, added to the supported models documentation, and a detailed implementation guide has been provided.
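For a quick sense of how the example is meant to be used, here is a minimal offline-embedding sketch. The task flag, the Qwen2-VL-style vision placeholder prompt, and the image path are assumptions for illustration; the actual example script in this PR may differ.

from PIL import Image
from vllm import LLM

# Hypothetical usage sketch; arguments may differ from the PR's example script.
llm = LLM(model="jinaai/jina-embeddings-v4-vllm-retrieval", task="embed")

# Text embedding
text_outputs = llm.embed("Query: What is the capital of France?")
print(len(text_outputs[0].outputs.embedding))

# Image embedding: pass the image via multi_modal_data alongside a prompt that
# contains the vision placeholder tokens (Qwen2-VL style, assumed here).
image_outputs = llm.embed({
    "prompt": "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
              "Describe the image.<|im_end|>\n",
    "multi_modal_data": {"image": Image.open("example.jpg")},
})
print(len(image_outputs[0].outputs.embedding))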
Code Review
This PR adds production-ready support for the Jina Embeddings V4 model. I've identified a bug in the tests, a performance issue in the core implementation, and some areas for code improvement in the example and validation scripts.
Force-pushed from 34f3e7f to d7d6b60
Thanks for contributing! Can you add this model to the test registry and supported models page?
# Triton kernel for optimized vision token extraction
if HAS_TRITON:
Will provide Triton performance benchmarks after finishing up some tasks in the PR.
If this Triton kernel is only used in the pooler, I think the performance improvement will be very little. But it would be best to have performance benchmarks first.
Can you perform benchmarking on this?
@sigridjineth I did some benchmarks comparing the Triton kernel and the torch-native implementation on an RTX 3090, and found that the Triton kernel can be much slower when the image seq_len is quite long, which is a normal image input case for Qwen2-VL-like models:
Benchmark results (all runs with text sequence length 2048):

| Image seq len | # images | Triton vision pooling (s) | Native vision pooling (s) |
|---|---|---|---|
| 512 | 1 | 0.0877 | 0.0567 |
| 1024 | 1 | 0.1028 | 0.0344 |
| 8192 | 1 | 0.3178 | 0.0750 |
| 16384 | 1 | 0.5706 | 0.1178 |
| 512 | 2 | 0.0901 | 0.0320 |
| 1024 | 2 | 0.1074 | 0.0352 |
| 8192 | 2 | 0.3502 | 0.0757 |
| 16384 | 2 | 0.6468 | 0.1203 |
| 512 | 4 | 0.0951 | 0.0326 |
| 1024 | 4 | 0.1170 | 0.0354 |
| 8192 | 4 | 0.4278 | 0.0743 |
| 16384 | 4 | 0.8104 | 0.1189 |
Any idea about this? The benchmark script can be found here: https://gist.github.com/Isotr0py/eef7470ff176a28ac40340b883cf1abe
@triton.jit
def extract_vision_tokens_kernel(
    hidden_states_ptr,
    token_ids_ptr,
    output_ptr,
    seq_start,
    seq_len,
    hidden_size,
    vision_start_id: tl.constexpr,
    vision_end_id: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
I don't like putting a Triton kernel in the model implementation; we should move this to pooler.py or somewhere else if the performance improvement is significant.
Done: caea1fe
Force-pushed from 4eb5e88 to caea1fe
@Isotr0py @DarkLight1337 please review and let me know if you think more changes are needed.
Sorry for the delay, can you merge from main and fix pre-commit?
Address DarkLight1337's review feedback:
- Set logits_processing_needs_token_ids=True for V1 compatibility in both "embed" and "encode" tasks
- Support "encode" task by returning PoolingParams() instead of None
- Update log message from "thread-safe pooling" to "vision-aware pooling" to better reflect the actual functionality
- Remove unused seq_ids variable from _extract_token_ids_safe method

These changes ensure proper V1 compatibility and cleaner code structure.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
vllm/model_executor/layers/pooler.py (Outdated)
for i in range(seq_len):
    token_id = tl.load(token_ids_ptr + seq_start + i)
    if token_id >= vision_start_id and token_id <= vision_end_id:
Suggested change:
-if token_id >= vision_start_id and token_id <= vision_end_id:
+if token_id in (vision_start_id, vision_end_id):
Actually, now that I think of it, this isn't quite correct? The start index should be the first item that equals vision_start_id, and then all subsequent tokens (regardless of ID) are included until vision_end_id is found.
According to your code, you only selected the tokens corresponding to vision_start_id and vision_end_id, but not the tokens in between them.
See how the official code handles it: https://huggingface.co/jinaai/jina-embeddings-v4/blob/main/modeling_jina_embeddings_v4.py#L228
Can you fix this?
Hi @DarkLight1337, thanks for highlighting this. After looking at it again, your assessment is correct: there is a positional indexing mistake in the implementation.
The problem, I think, is that the kernel incorrectly selects only the tokens that exactly match vision_start_id and vision_end_id. It fails to select the intermediate tokens between these markers because it uses direct ID matching instead of positional masking. I have pushed a commit that introduces a dedicated VisionPooler class which finds the positions of vision_start_id and vision_end_id via torch.where, so that the entire positional range is pooled. Would like to get your feedback on this: 5114a3c
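For illustration, a minimal sketch of the positional approach, in the spirit of the official modeling code linked above; the fallback to mean pooling over all tokens for text-only sequences is an assumption, not necessarily what the PR does.

import torch

def pool_vision_span(hidden_states: torch.Tensor,  # [seq_len, hidden_size]
                     token_ids: torch.Tensor,      # [seq_len]
                     vision_start_id: int,
                     vision_end_id: int) -> torch.Tensor:
    """Mean-pool every token between the first vision start/end markers by
    position, regardless of the token IDs in between."""
    starts = torch.where(token_ids == vision_start_id)[0]
    ends = torch.where(token_ids == vision_end_id)[0]
    if starts.numel() == 0 or ends.numel() == 0:
        # Assumed fallback for text-only sequences: mean over all tokens.
        return hidden_states.mean(dim=0)
    start, end = int(starts[0]), int(ends[0])
    return hidden_states[start:end + 1].mean(dim=0)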
Implement efficiency improvements suggested by DarkLight1337:
- Consolidate get_pooling_params method for "embed" and "encode" tasks
- Pre-compute vision token IDs tensor in constructor
- Replace range checks with torch.isin for more efficient vision token detection at lines 209-210 and 261-262

This reduces redundant code and improves performance when checking for vision tokens by using optimized tensor operations.

Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
def extract_embeddings(output):
    """Extract embeddings based on token type."""
    if VISION_START_TOKEN_ID in output.prompt_token_ids:
We should test the whole embedding tensor against HF to avoid these kinds of mistakes
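As a sketch, that comparison could boil down to a helper like the one below. How the HF reference embedding is produced is left out here (jina-embeddings-v4 ships custom modeling code on the Hub, as linked above); only the whole-tensor check is shown.

import torch
import torch.nn.functional as F

def assert_embeddings_match(vllm_embedding: list[float],
                            hf_embedding: list[float],
                            min_cosine: float = 0.99) -> None:
    """Compare the full embedding tensors, not just which pooling branch ran."""
    a = torch.tensor(vllm_embedding, dtype=torch.float32)
    b = torch.tensor(hf_embedding, dtype=torch.float32)
    cos = F.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
    assert cos >= min_cosine, f"embedding mismatch: cosine similarity {cos:.4f}"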
def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
    """Return pooling params for embedding task."""
    if task == "embed" or task == "encode":
        return PoolingParams(logits_processing_needs_token_ids=True)

    # The equalities are split up to keep mypy happy
    if task == "classify" or task == "score":
        return None

    assert_never(task)
Can you merge from main and then apply this update? (Need to update the imports accordingly as well)
Suggested change:
-def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-    """Return pooling params for embedding task."""
-    if task == "embed" or task == "encode":
-        return PoolingParams(logits_processing_needs_token_ids=True)
-    # The equalities are split up to keep mypy happy
-    if task == "classify" or task == "score":
-        return None
-    assert_never(task)
+def get_pooling_updates(
+    self,
+    task: PoolingTask,
+) -> Optional[PoolingParamsUpdate]:
+    # The equalities are split up to keep mypy happy
+    if task == "encode" or task == "embed":
+        return PoolingParamsUpdate(requires_token_ids=True)
+    if task == "classify" or task == "score":
+        return None
+    assert_never(task)
def _extract_token_ids_safe(
        self, pooling_metadata: PoolingMetadata) -> list[array]:
    """Safely extract token IDs from pooling metadata."""
    token_ids_list: list[array] = []
    try:
        if isinstance(pooling_metadata, V1PoolingMetadata):
            # For V1, we get token IDs directly
            for i, num in enumerate(pooling_metadata.prompt_lens):
                token_ids = pooling_metadata.prompt_token_ids[
                    i, :num].tolist()
                token_ids_list.append(array('l', token_ids))

            return token_ids_list

        # For V0, we extract from seq_groups and seq_data
        for seq_group, _ in pooling_metadata.seq_groups:
            for seq_id in seq_group:
                if seq_id not in pooling_metadata.seq_data:
                    logger.warning("Sequence %s not found in seq_data",
                                   seq_id)
                    continue

                seq_data = pooling_metadata.seq_data[seq_id]

                # Get prompt token IDs safely
                if hasattr(seq_data, 'prompt_token_ids_array'):
                    token_ids = seq_data.prompt_token_ids_array
                elif hasattr(seq_data, '_prompt_token_ids'):
                    token_ids = seq_data._prompt_token_ids
                else:
                    logger.warning("No token IDs found for sequence %s",
                                   seq_id)
                    continue

                token_ids_list.append(token_ids)

        return token_ids_list

    except Exception as e:
        logger.error(
            "Error extracting token IDs: %s. "
            "Extracted %d sequences before failure", e,
            len(token_ids_list))
        raise
The latest main now has get_prompt_token_ids in pooler.py, which can replace this functionality (but note that it outputs torch.Tensor instead of array.array).
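A sketch of what that swap could look like, assuming get_prompt_token_ids takes the pooling metadata and yields one tensor of prompt token IDs per sequence, as described above:

import torch

from vllm.model_executor.layers.pooler import get_prompt_token_ids

# Sketch only: replaces the manual V0/V1 branching above. Downstream code must
# handle torch.Tensor token IDs instead of array.array.
def extract_token_ids(pooling_metadata) -> list[torch.Tensor]:
    return list(get_prompt_token_ids(pooling_metadata))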
Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
Force-pushed from 6b31c66 to 5114a3c
@@ -32,6 +32,7 @@ class PoolingType(IntEnum):
    CLS = 2
    STEP = 3
    MEAN = 4
    VISION = 5
I have created this new vision pooling type in the PoolingType enum.
Force-pushed from 073e87c to c50992a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import gc
Can you test the correctness of the model against the HF implementation?
Signed-off-by: Sigrid Jin (Sionic AI) <sigrid@sionic.ai>
Force-pushed from c50992a to 6b501b2
@@ -3258,7 +3258,8 @@ def get_limit_per_prompt(self, modality: str) -> int:
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

-   pooling_type: Optional[str] = None
+   pooling_type: Optional[Literal["last", "all", "cls", "step", "mean",
Actually, the pooling type here is supposed to be uppercase.
def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
    return cls(model_config)

def __init__(self, config: ModelConfig):
Can we pass in the token IDs and hidden size explicitly, in case other models store those attributes in different locations?
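A rough sketch of such a signature (hypothetical parameter names; the real class would implement vLLM's pooler interface rather than bare nn.Module):

import torch.nn as nn

from vllm.config import ModelConfig


class VisionPooler(nn.Module):
    """Sketch: take the vision marker token IDs and hidden size explicitly."""

    def __init__(self, vision_start_token_id: int, vision_end_token_id: int,
                 hidden_size: int):
        super().__init__()
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        self.hidden_size = hidden_size

    @classmethod
    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        # Convenience lookup for Qwen2-VL style configs; models that store these
        # attributes elsewhere can call the constructor directly.
        hf_config = model_config.hf_config
        return cls(
            vision_start_token_id=hf_config.vision_start_token_id,
            vision_end_token_id=hf_config.vision_end_token_id,
            hidden_size=hf_config.hidden_size,
        )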
super().__init__(vllm_config=vllm_config,
                 prefix=maybe_prefix(prefix, "qwen2_vl"))

self.pooler = JinaVLPooler(vllm_config)
Why not directly use VisionPooler here?
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
This PR adds support for the Jina Embeddings V4 model (jinaai/jina-embeddings-v4-vllm-retrieval) in vLLM, enabling multimodal embeddings for text and image inputs.

FIX #20463
Test Plan